{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from torchtext import data, datasets\n",
    "import pandas as pd\n",
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "base_dir = os.getcwd()\n",
    "atis_data = os.path.join(base_dir, 'atis')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "    构建训练集与验证集\n",
    "'''\n",
    "def build_dataset():\n",
    "    \n",
    "    tokenize = lambda s:s.split()\n",
    "    \n",
    "    SOURCE = data.Field(sequential=True, tokenize=tokenize,\n",
    "                        lower=True, use_vocab=True,\n",
    "                        init_token='<sos>', eos_token='<eos>',\n",
    "                        pad_token='<pad>', unk_token='<unk>',\n",
    "                        batch_first=True, fix_length=50,\n",
    "                        include_lengths=True) #include_lengths=True为方便之后使用torch的pack_padded_sequence\n",
    "    \n",
    "    TARGET = data.Field(sequential=True, tokenize=tokenize,\n",
    "                        lower=True, use_vocab=True,\n",
    "                        init_token='<sos>', eos_token='<eos>',\n",
    "                        pad_token='<pad>', unk_token='<unk>',\n",
    "                        batch_first=True, fix_length=50,\n",
    "                        include_lengths=True) #include_lengths=True为方便之后使用torch的pack_padded_sequence\n",
    "    LABEL = data.Field(\n",
    "                    sequential=False,\n",
    "                    use_vocab=True)\n",
    "    \n",
    "    train, val = data.TabularDataset.splits(\n",
    "                                            path=atis_data,\n",
    "                                            skip_header=True,\n",
    "                                            train='atis.train.csv',\n",
    "                                            validation='atis.test.csv',\n",
    "                                            format='csv',\n",
    "                                            fields=[('index', None), ('intent', LABEL), ('source', SOURCE), ('target', TARGET)])\n",
    "    print('train data info:')\n",
    "    print(len(train))\n",
    "    print(vars(train[0]))\n",
    "    print('val data info:')\n",
    "    print(len(val))\n",
    "    print(vars(val[0]))\n",
    "    \n",
    "    SOURCE.build_vocab(train, val)\n",
    "    TARGET.build_vocab(train, val)\n",
    "    LABEL.build_vocab(train, val)\n",
    "    \n",
    "    print('vocab info:')\n",
    "    print('source vocab size:{}'.format(len(SOURCE.vocab)))\n",
    "    print('target vocab size:{}'.format(len(TARGET.vocab)))\n",
    "    print('label vocab size:{}'.format(len(LABEL.vocab)))\n",
    "    \n",
    "    \n",
    "    #train_iter, val_iter = data.BucketIterator.splits(\n",
    "    #                                                (train, val),\n",
    "    #                                                batch_sizes=(128, len(val)),\n",
    "    #                                                #shuffle=True,\n",
    "    #                                                sort_within_batch=True, #为true则一个batch内的数据会按sort_key规则降序排序\n",
    "    #                                                sort_key=lambda x: len(x.source)) #这里按src的长度降序排序，主要是为后面pack,pad操作)\n",
    "\n",
    "    train_iter, val_iter = data.Iterator.splits(\n",
    "                                                (train, val),\n",
    "                                                batch_sizes=(128, len(val)), # 训练集设置为128,验证集整个集合用于测试\n",
    "                                                shuffle=True,\n",
    "                                                sort_within_batch=True, #为true则一个batch内的数据会按sort_key规则降序排序\n",
    "                                                sort_key=lambda x: len(x.source)) #这里按src的长度降序排序，主要是为后面pack,pad操作)\n",
    "    \n",
    "    return train_iter, val_iter\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train data info:\n",
      "4978\n",
      "{'intent': 'flight', 'source': ['i', 'want', 'to', 'fly', 'from', 'boston', 'at', '838', 'am', 'and', 'arrive', 'in', 'denver', 'at', '1110', 'in', 'the', 'morning'], 'target': ['o', 'o', 'o', 'o', 'o', 'b-fromloc.city_name', 'o', 'b-depart_time.time', 'i-depart_time.time', 'o', 'o', 'o', 'b-toloc.city_name', 'o', 'b-arrive_time.time', 'o', 'o', 'b-arrive_time.period_of_day']}\n",
      "val data info:\n",
      "893\n",
      "{'intent': 'flight', 'source': ['i', 'would', 'like', 'to', 'find', 'a', 'flight', 'from', 'charlotte', 'to', 'las', 'vegas', 'that', 'makes', 'a', 'stop', 'in', 'st.', 'louis'], 'target': ['o', 'o', 'o', 'o', 'o', 'o', 'o', 'o', 'b-fromloc.city_name', 'o', 'b-toloc.city_name', 'i-toloc.city_name', 'o', 'o', 'o', 'o', 'o', 'b-stoploc.city_name', 'i-stoploc.city_name']}\n",
      "vocab info:\n",
      "source vocab size:945\n",
      "target vocab size:133\n",
      "label vocab size:27\n",
      "train_iter size:39\n",
      "val_iter size:1\n"
     ]
    }
   ],
   "source": [
    "train_iter, val_iter = build_dataset()\n",
    "print('train_iter size:{}'.format(len(train_iter)))\n",
    "print('val_iter size:{}'.format(len(val_iter)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([ 1,  1,  1,  1,  2,  1,  1,  6, 26,  2,  1,  1,  1,  1,  1,  1,  1,  1,\n",
      "         1,  2,  1,  1,  2,  1,  1,  1,  2,  1,  1,  1,  1,  4,  1,  1,  1,  1,\n",
      "         6,  1,  2,  1,  1,  1,  7,  2,  1,  1,  1,  1,  1,  1,  4,  1,  1,  1,\n",
      "         1,  1,  2,  1,  1,  7,  3,  4,  8,  1,  1,  1,  1,  1, 11,  1,  1,  1,\n",
      "         2,  1,  1,  2,  1,  1,  1,  1,  1,  1,  1,  8,  1,  2,  1,  2,  1,  1,\n",
      "         1,  1,  1,  1,  8,  1,  4,  3,  1,  1,  1,  1,  2,  1,  1,  1,  1,  1,\n",
      "         1,  4,  1,  1,  1,  3,  1,  1,  2,  1, 18,  1,  3,  1, 10,  1,  1,  1,\n",
      "         1,  1])\n",
      "(tensor([[  2,  13,  81,  ...,   1,   1,   1],\n",
      "        [  2,  13,  40,  ...,   1,   1,   1],\n",
      "        [  2,  13, 189,  ...,   1,   1,   1],\n",
      "        ...,\n",
      "        [  2,  38,  11,  ...,   1,   1,   1],\n",
      "        [  2,   6,   5,  ...,   1,   1,   1],\n",
      "        [  2,   6,   5,  ...,   1,   1,   1]]), tensor([24, 23, 23, 20, 20, 19, 19, 19, 19, 18, 18, 18, 18, 18, 18, 18, 17, 17,\n",
      "        17, 17, 17, 17, 17, 17, 17, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 15,\n",
      "        15, 15, 15, 15, 15, 15, 15, 15, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,\n",
      "        14, 14, 14, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13,\n",
      "        13, 13, 13, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 11,\n",
      "        11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 10, 10, 10, 10, 10,\n",
      "        10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,  9,  9,  9,  9,  9,  8,\n",
      "         8,  7]))\n",
      "(tensor([[ 2,  4,  4,  ...,  1,  1,  1],\n",
      "        [ 2,  4,  4,  ...,  1,  1,  1],\n",
      "        [ 2,  4,  4,  ...,  1,  1,  1],\n",
      "        ...,\n",
      "        [ 2, 11,  4,  ...,  1,  1,  1],\n",
      "        [ 2,  4,  4,  ...,  1,  1,  1],\n",
      "        [ 2,  4,  4,  ...,  1,  1,  1]]), tensor([24, 23, 23, 20, 20, 19, 19, 19, 19, 18, 18, 18, 18, 18, 18, 18, 17, 17,\n",
      "        17, 17, 17, 17, 17, 17, 17, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 15,\n",
      "        15, 15, 15, 15, 15, 15, 15, 15, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,\n",
      "        14, 14, 14, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13,\n",
      "        13, 13, 13, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 11,\n",
      "        11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 10, 10, 10, 10, 10,\n",
      "        10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,  9,  9,  9,  9,  9,  8,\n",
      "         8,  7]))\n"
     ]
    }
   ],
   "source": [
    "for i,batch in enumerate(train_iter):\n",
    "    print(batch.intent)\n",
    "    print(batch.source)\n",
    "    print(batch.target)\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "execution_count": 65,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
